{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Train a Model, Evaluate, and Use in the Search Engine\n",
    "\n",
    "In these last few steps we train a model using the pairwise training set generated [in the previous step](3.pairwise-transform.ipynb)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "from itertools import groupby\n",
    "import numpy\n",
    "import random\n",
    "import sys\n",
    "sys.path.append('../..')\n",
    "from aips import *\n",
    "import json\n",
    "\n",
    "engine = get_engine()\n",
    "tmdb_collection = engine.get_collection(\"tmdb\")\n",
    "ltr = get_ltr_engine(tmdb_collection)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Reload judgments & training set\n",
    "\n",
    "Load the dataset generated [from the previous section](3.pairwise-transform.ipynb)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Parsing QID 100\n"
     ]
    }
   ],
   "source": [
    "from ltr.judgments import judgments_open\n",
    "\n",
    "predictor_deltas = numpy.load(\"data/predictor_deltas.npy\")\n",
    "feature_data = numpy.load(\"data/feature_data.npy\")\n",
    "\n",
    "std_devs = feature_data[-1]\n",
    "means = feature_data[-2]\n",
    "feature_deltas = feature_data[:-2]\n",
    "\n",
    "normed_judgments = []\n",
    "with judgments_open(\"data/normed_judgments.txt\") as judg_list:\n",
    "    for j in judg_list:\n",
    "        normed_judgments.append(j)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Listing 10.12\n",
    "\n",
    "Train the model with the fully transformed dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([0.37611453, 0.26665264, 0.09971659])"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "random.seed(0)\n",
    "\n",
    "from sklearn import svm\n",
    "model = svm.LinearSVC(max_iter=10000)\n",
    "model.fit(feature_deltas, predictor_deltas)\n",
    "display(model.coef_[0])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## A few sample features (omitted from book)\n",
    "\n",
    "Gathering features from a few movies \"Star Trek II: The Wrath of Khan\" and \"Star Trek III: Search for Spock\" to kick the tires of our model."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[-0.4300788728590096, -0.4553915319569039, -0.44511783129197596]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# If you wanted to confirm Wrath of Khans features\n",
    "\n",
    "ids = [\"154\"] #social network graded documents\n",
    "options = {\"keywords\": \"wrath of khan\"}\n",
    "response = ltr.get_logged_features(\"movie_model\", ids, options=options, fields=[\"id\", \"title\"])\n",
    "\n",
    "# Features from the search engine\n",
    "# Wrath of Khan\n",
    "wok_features = [5.9217176, 3.401492, 1982.0]\n",
    "# Search For Spock\n",
    "spock_features = [0.0,0.0,1984.0]\n",
    "\n",
    "# Wrath of Khan\n",
    "normed_wok_features = [0, 0, 0]\n",
    "for idx, f in enumerate(wok_features):\n",
    "    normed_wok_features[idx] = (f - means[idx]) / std_devs[idx]\n",
    "\n",
    "normed_spock_features = [0, 0, 0]\n",
    "for idx, f in enumerate(spock_features):\n",
    "    normed_spock_features[idx] = (f - means[idx]) / std_devs[idx]\n",
    "    \n",
    "display(normed_spock_features)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Taking the model for test drive... (omitted from book)\n",
    "\n",
    "Here we score a few documents with the model. This code is omitted from the book, but is explored in section 10.6.2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "-0.3050363943771231"
      ]
     },
     "execution_count": 16,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "def score_one(features, model):\n",
    "    score = 0.0\n",
    "    for idx, f in enumerate(features):\n",
    "        this_coef = model.coef_[0][idx].item()\n",
    "        score += f * this_coef\n",
    "    \n",
    "    return score\n",
    "\n",
    "def rank(query_judgments, model):\n",
    "    for j in query_judgments:\n",
    "        j.score = score_one(j.features, model)\n",
    "    \n",
    "    return sorted(query_judgments, key=lambda j: j.score, reverse=True)\n",
    "\n",
    "score_one(normed_spock_features, model)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Wrath of Khan should score higher"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "1.4957538659013554"
      ]
     },
     "execution_count": 17,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "score_one(normed_wok_features, model)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Listing 10.13 Test Training Split"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "random.seed(1234)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [],
   "source": [
    "all_qids = list(set([j.qid for j in normed_judgments]))\n",
    "random.shuffle(all_qids)\n",
    "proportion_train = 0.1\n",
    "\n",
    "split_index = int(len(all_qids) * proportion_train)\n",
    "test_qids = all_qids[:split_index]\n",
    "train_qids = all_qids[split_index:]\n",
    "\n",
    "train_data, test_data= [], []\n",
    "for j in normed_judgments:\n",
    "    if j.qid in train_qids:\n",
    "        train_data.append(j)\n",
    "    elif j.qid in test_qids:\n",
    "        test_data.append(j)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Repeated from earlier - parwise transform\n",
    "\n",
    "You've already seen this code in the third notebook, so you can move on. We just need it here to do a pairwise_transform of the training data to train a model."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy\n",
    "\n",
    "def pairwise_transform(normed_judgments):\n",
    "        \n",
    "    from itertools import groupby\n",
    "    predictor_deltas = []\n",
    "    feature_deltas = []\n",
    "    \n",
    "    # For each query's judgments\n",
    "    for qid, query_judgments in groupby(normed_judgments, key=lambda j: j.qid):\n",
    "\n",
    "        # Annoying issue consuming python iterators, we ensure we have two\n",
    "        # full copies of each query's judgments\n",
    "        query_judgments_copy_1 = list(query_judgments) \n",
    "        query_judgments_copy_2 = list(query_judgments_copy_1)\n",
    "\n",
    "        # Examine every judgment combo for this query, \n",
    "        # if they're different, store the pairwise difference:\n",
    "        # +1 if judgment1 more relevant\n",
    "        # -1 if judgment2 more relevant\n",
    "        for judgment1 in query_judgments_copy_1:\n",
    "            for judgment2 in query_judgments_copy_2:\n",
    "                \n",
    "                j1_features=numpy.array(judgment1.features)\n",
    "                j2_features=numpy.array(judgment2.features)\n",
    "                \n",
    "                if judgment1.grade > judgment2.grade:\n",
    "                    predictor_deltas.append(+1)\n",
    "                    feature_deltas.append(j1_features - j2_features)\n",
    "                elif judgment1.grade < judgment2.grade:\n",
    "                    predictor_deltas.append(-1)\n",
    "                    feature_deltas.append(j1_features - j2_features)\n",
    "\n",
    "    # For training purposes, we return these as numpy arrays\n",
    "    return numpy.array(feature_deltas), numpy.array(predictor_deltas)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Listing 10.14 - train on just train data\n",
    "\n",
    "We repeat the model training process only on the train subset of the queries. Notice because our test/training split is at the query level we repeat the pairwise transform we did earlier"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[LibLinear]"
     ]
    },
    {
     "data": {
      "text/plain": [
       "array([0.34455363, 0.26599357, 0.08024811])"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "train_feature_deltas, train_predictor_deltas = pairwise_transform(train_data)\n",
    "\n",
    "from sklearn import svm\n",
    "model = svm.LinearSVC(max_iter=10000, verbose=1)\n",
    "model.fit(train_feature_deltas, train_predictor_deltas)\n",
    "display(model.coef_[0])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Listing 10.15 - eval model on test data\n",
    "\n",
    "Here we compute a simple precision metric (proportion of relevant results in top N) averaged over all the test data. It's important to note this is not a very robust statistical analysis of the model's quality, we would want to perform multiple test-training samples and perform statistical significance testing between this experiment and a baseline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.28\n"
     ]
    }
   ],
   "source": [
    "def evaluate_model(test_data, model, k=5):\n",
    "    total_precision = 0\n",
    "    unique_queries = groupby(test_data, lambda j: j.qid)\n",
    "    num_groups = 0\n",
    "    for qid, query_judgments in unique_queries:\n",
    "        num_groups += 1\n",
    "        ranked = rank(list(query_judgments), model)\n",
    "        total_relevant = len([j for j in ranked[:k] if j.grade == 1])\n",
    "        total_precision += total_relevant / float(k)\n",
    "    return total_precision / num_groups\n",
    "\n",
    "evaluation = evaluate_model(test_data, model)\n",
    "print(evaluation)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Listing 10.16 - A search engine LTR model\n",
    "\n",
    "This turns the model into one usable by the search engine\n",
    "\n",
    "- The weights for each (normalized) feature\n",
    "- The means to use to normalize each feature\n",
    "- The std deviation used to normalize each feature"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'store': 'movie_model',\n",
       " 'class': 'org.apache.solr.ltr.model.LinearModel',\n",
       " 'name': 'movie_model',\n",
       " 'features': [{'name': 'title_bm25',\n",
       "   'norm': {'class': 'org.apache.solr.ltr.norm.StandardNormalizer',\n",
       "    'params': {'avg': '0.7240581864747018', 'std': '1.683547442499148'}}},\n",
       "  {'name': 'overview_bm25',\n",
       "   'norm': {'class': 'org.apache.solr.ltr.norm.StandardNormalizer',\n",
       "    'params': {'avg': '0.6904480748206504', 'std': '1.5161636226604036'}}},\n",
       "  {'name': 'release_year',\n",
       "   'norm': {'class': 'org.apache.solr.ltr.norm.StandardNormalizer',\n",
       "    'params': {'avg': '1993.045660222131', 'std': '20.32194530575292'}}}],\n",
       " 'params': {'weights': {'title_bm25': 0.34455363381155396,\n",
       "   'overview_bm25': 0.2659935690511117,\n",
       "   'release_year': 0.08024810647908147}}}"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "model_name = \"movie_model\"\n",
    "feature_names = [\"title_bm25\", \"overview_bm25\", \"release_year\"]\n",
    "linear_model = ltr.generate_model(model_name, feature_names,\n",
    "                                  means, std_devs, model.coef_[0])\n",
    "response = ltr.upload_model(linear_model)\n",
    "display(linear_model)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Listing 10.17 - Search with the trained LTR model\n",
    "\n",
    "Executing a search with the LTR model reranking (expensive)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Returned Documents:\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "[{'id': '570724', 'title': 'The Story of Harry Potter', 'score': 2.1900952},\n",
       " {'id': '116972',\n",
       "  'title': 'Discovering the Real World of Harry Potter',\n",
       "  'score': 2.0560164},\n",
       " {'id': '672',\n",
       "  'title': 'Harry Potter and the Chamber of Secrets',\n",
       "  'score': 1.8415655},\n",
       " {'id': '671',\n",
       "  'title': \"Harry Potter and the Philosopher's Stone\",\n",
       "  'score': 1.8207769},\n",
       " {'id': '54507', 'title': 'A Very Potter Musical', 'score': 1.80214}]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "request = {\"query\": \"harry potter\",\n",
    "           \"query_fields\": [\"title\", \"overview\"],\n",
    "           \"return_fields\": [\"title\", \"id\", \"score\"]}\n",
    "response = ltr.search_with_model(\"movie_model\", **request)\n",
    "print(\"\\nReturned Documents:\")\n",
    "display(response[\"docs\"])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Listing 10.18 - A search query utilizing a baseline search and rerank utilizing the model\n",
    "\n",
    "Issuing a simple lexical query followed by a rerank the top 500 documents using the LTR model (optimized)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Returned Documents:\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "[{'id': '570724', 'title': 'The Story of Harry Potter', 'score': 2.1900952},\n",
       " {'id': '116972',\n",
       "  'title': 'Discovering the Real World of Harry Potter',\n",
       "  'score': 2.0560164},\n",
       " {'id': '672',\n",
       "  'title': 'Harry Potter and the Chamber of Secrets',\n",
       "  'score': 1.8415655},\n",
       " {'id': '671',\n",
       "  'title': \"Harry Potter and the Philosopher's Stone\",\n",
       "  'score': 1.8207769},\n",
       " {'id': '54507', 'title': 'A Very Potter Musical', 'score': 1.80214}]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "request = {\"query\": \"harry potter\",\n",
    "           \"query_fields\": [\"title\", \"overview\"],\n",
    "           \"return_fields\": [\"title\", \"id\", \"score\"],\n",
    "           \"rerank_query\": \"harry potter\",\n",
    "           \"rerank_count\": 500}\n",
    "response = ltr.search_with_model(\"movie_model\", **request)\n",
    "print(\"\\nReturned Documents:\")\n",
    "display(response[\"docs\"])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Rinse and repeat!\n",
    "\n",
    "What would you change about this model or the features used? Maybe revisit [the features](2.judgments-and-logging.ipynb) to explore some different ideas?\n",
    "\n",
    "Up next: [Chapter 11: Automating Learning to Rank with Click Models](../ch11/0.setup.ipynb)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
